Essential question: What is Hamiltonian Monte Carlo (HMC) and how do we use it to do Bayesian inference?

Start-of-class work (5 - 10 min)

Watch a Metropolis algorithm “get stuck” on the “donut” example. Discuss why this may be occurring.

mcmc demo

I. Why HMC? (Ch.20; 5 - 10 min)

Overview of MCMC Strategies

  • Metropolis: Granddaddy of them all
  • Metropolis-Hastings (MH): More general
  • Gibbs sampling (GS): Efficient version of MH
  • Metropolis and Gibbs are “guess and check” strategies
  • Hamiltonian Monte Carlo (HMC) is fundamentally different: it uses the gradient
  • New methods are being developed, but the future belongs to the gradient

Excellent reference text: Brooks, S., Gelman, A., Jones, G., & Meng, X. L. (Eds.). (2011). Handbook of Markov Chain Monte Carlo. CRC Press.

Problem with Gibbs sampling (GS)

  • High-dimensional posteriors concentrate in a narrow region (the “typical set”)
  • GS gets stuck and degenerates toward a random walk
  • Inefficient because it re-explores the same regions

Hamiltonian dynamics to the rescue

  • Represent the parameter state as a particle
  • Flick it around a frictionless surface shaped by the log-posterior
  • Record its positions
  • No more “guess and check”, as (nearly) all proposals are good proposals
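The steps above can be sketched in a few lines. This is an illustrative toy in Python, not Stan’s implementation: the target is a 1-D standard normal, so the potential energy is U(q) = q²/2 and its gradient is just q; all names and tuning values are my own.

```python
# Toy HMC sketch: 1-D standard normal target (illustrative, not Stan).
import math
import random

def U(q):                 # potential energy = negative log-posterior
    return 0.5 * q * q

def grad_U(q):            # its gradient (this is what HMC exploits)
    return q

def hmc_step(q, eps, L):
    p = random.gauss(0.0, 1.0)            # "flick": draw a random momentum
    q_new, p_new = q, p
    p_new -= 0.5 * eps * grad_U(q_new)    # leapfrog: half step for momentum
    for i in range(L):
        q_new += eps * p_new              # full step for position
        if i < L - 1:
            p_new -= eps * grad_U(q_new)  # full step for momentum
    p_new -= 0.5 * eps * grad_U(q_new)    # final half step for momentum
    # Metropolis correction for discretization error; with a stable step
    # size the energy error is tiny, so (nearly) all proposals are accepted.
    h_old = U(q) + 0.5 * p * p
    h_new = U(q_new) + 0.5 * p_new * p_new
    return q_new if random.random() < math.exp(h_old - h_new) else q

random.seed(0)
q, draws = 0.0, []
for _ in range(5000):
    q = hmc_step(q, eps=0.2, L=15)        # record positions along the way
    draws.append(q)
```

Because each proposal follows the gradient for many steps, successive draws are nearly independent, unlike the small random-walk moves of Metropolis.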

II. Brief overview of HMC (5 - 10 min)

From MacKay, David JC. Information theory, inference and learning algorithms. Cambridge university press, 2003.

“The HMC method is a Metropolis method, applicable to continuous state spaces, that makes use of gradient information to reduce random walk behavior… It seems wasteful to use a simple random-walk Metropolis method when this gradient is available - the gradient indicates which direction one should go in to find states that have higher probability!”

See MacKay, Section 30.1 for a basic implementation of HMC and mathematical details. Let’s return to the MCMC demonstration and I’ll discuss the intuition.

  • Standard normal (“hill”)
  • Donut (like a “bowl” that follows the donut’s shape)

mcmc demo

More on HMC

III. Using software to do HMC (25 - 30 min)

Stan

  • No-U-Turn Sampler (NUTS): adaptive Hamiltonian Monte Carlo
  • Implemented in Stan (rstan: mc-stan.org)
  • Stan figures out the gradient for you via automatic differentiation (autodiff)

Problem with regular HMC: U-turns

  • Increase the number of leapfrog steps in regular HMC on the normal density and watch the trajectory turn back on itself
  • Donut (like a “bowl” that follows the donut’s shape)
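A simplified version of the stopping rule NUTS uses can be demonstrated on a toy 1-D normal target in Python: follow a leapfrog trajectory and stop once the momentum points back toward the starting point. Function names and tuning values here are illustrative, not Stan’s.

```python
# Detect a U-turn on a 1-D standard normal target (illustrative sketch).
def grad_U(q):
    return q  # gradient of U(q) = q^2 / 2

def steps_until_uturn(q0, p0, eps=0.1, max_steps=1000):
    """Leapfrog until the trajectory starts heading back toward q0."""
    q, p = q0, p0
    for step in range(1, max_steps + 1):
        p -= 0.5 * eps * grad_U(q)   # half step for momentum
        q += eps * p                 # full step for position
        p -= 0.5 * eps * grad_U(q)   # half step for momentum
        if (q - q0) * p < 0:         # momentum now points back at the start
            return step
    return None
```

On the normal density the dynamics are (nearly) periodic, so leapfrog steps taken past this point just retrace old ground. Regular HMC uses a fixed number of steps and can waste that effort; NUTS grows the trajectory adaptively and stops near this condition instead.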

mcmc demo

Advantages of Stan

Derived from

Gelman, Andrew, Daniel Lee, and Jiqiang Guo. “Stan: A probabilistic programming language for Bayesian inference and optimization.” Journal of Educational and Behavioral Statistics 40.5 (2015): 530-543.

“Stan was motivated by the desire to solve problems that could not be solved in reasonable time (user programming time plus run time) using other packages.

In comparing Stan to other software options, we consider several criteria: 1. Flexibility, that is, being able to fit the desired model. 2. Ease of use; user programming time. 3. Run time. 4. Scalability as dataset and model grow larger.”

A quick walkthrough of a detailed example of rstan for meta-analysis

http://mc-stan.org/rstan/articles/rstan.html

IV. Check the chains: diagnostics and a way to fix (10 - 15 min)

Sometimes it doesn’t work

  • Good chains
    • Converge to same target distribution
    • Once there, explore efficiently
  • Different ways to check
    • Trace plots
    • Convergence diagnostics (n_eff , Rhat)
    • Special warnings (divergent transitions)

Trace plot

  • Check first
  • Shows some problems, but not all
  • Want to see a “hairy caterpillar”

Convergence diagnostics

  • n_eff: “effective” number of samples
    • If n_eff / n < 0.1, be alarmed
  • R-hat
    • R-hat: crudely, the ratio of total (between-chain plus within-chain) variance to within-chain variance
    • Should approach 1
  • Both diagnostics can mislead
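As a rough Python sketch of the idea (not Stan’s split-R-hat, which additionally halves each chain and rank-normalizes), the Gelman-Rubin R-hat can be computed from the chain means and variances:

```python
# Sketch of the (non-split) Gelman-Rubin R-hat for equal-length chains.
import random
import statistics as st

def rhat(chains):
    n = len(chains[0])
    means = [st.mean(c) for c in chains]
    W = st.mean([st.variance(c) for c in chains])  # within-chain variance
    B = n * st.variance(means)                     # between-chain variance
    var_hat = (n - 1) / n * W + B / n              # pooled variance estimate
    return (var_hat / W) ** 0.5

random.seed(1)
# Two chains sampling the same target: R-hat should be near 1.
good = [[random.gauss(0.0, 1.0) for _ in range(1000)] for _ in range(2)]
# Two chains stuck in different places: R-hat far above 1.
bad = [[random.gauss(m, 1.0) for _ in range(1000)] for m in (0.0, 5.0)]
```

If the chains have converged to the same distribution, the between-chain term adds almost nothing and R-hat approaches 1; chains exploring different regions inflate it well above 1.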

A wild chain

library(rethinking)
y <- c(-1, 1)
set.seed(11)
m9.2 <- ulam(
    alist(
        y ~ dnorm( mu, sigma ),
        mu <- alpha,
        alpha ~ dnorm(0, 1000) ,
        sigma ~ dexp( 0.0001)
    ),
    data = list(y = y), chains = 2)
## Cautionary note:
## Variable y contains only integers but is not type 'integer'. If you intend it as an index variable, you should as.integer() it before passing to ulam.
## recompiling to avoid crashing R session
## 
## SAMPLING FOR MODEL '726d002e27cec1633082261fcfedb813' NOW (CHAIN 1).
## Chain 1: 
## Chain 1: Gradient evaluation took 7e-06 seconds
## Chain 1: 1000 transitions using 10 leapfrog steps per transition would take 0.07 seconds.
## Chain 1: Adjust your expectations accordingly!
## Chain 1: 
## Chain 1: 
## Chain 1: Iteration:   1 / 1000 [  0%]  (Warmup)
## Chain 1: Iteration: 100 / 1000 [ 10%]  (Warmup)
## Chain 1: Iteration: 200 / 1000 [ 20%]  (Warmup)
## Chain 1: Iteration: 300 / 1000 [ 30%]  (Warmup)
## Chain 1: Iteration: 400 / 1000 [ 40%]  (Warmup)
## Chain 1: Iteration: 500 / 1000 [ 50%]  (Warmup)
## Chain 1: Iteration: 501 / 1000 [ 50%]  (Sampling)
## Chain 1: Iteration: 600 / 1000 [ 60%]  (Sampling)
## Chain 1: Iteration: 700 / 1000 [ 70%]  (Sampling)
## Chain 1: Iteration: 800 / 1000 [ 80%]  (Sampling)
## Chain 1: Iteration: 900 / 1000 [ 90%]  (Sampling)
## Chain 1: Iteration: 1000 / 1000 [100%]  (Sampling)
## Chain 1: 
## Chain 1:  Elapsed Time: 0.031784 seconds (Warm-up)
## Chain 1:                0.006581 seconds (Sampling)
## Chain 1:                0.038365 seconds (Total)
## Chain 1: 
## 
## SAMPLING FOR MODEL '726d002e27cec1633082261fcfedb813' NOW (CHAIN 2).
## Chain 2: 
## Chain 2: Gradient evaluation took 2e-06 seconds
## Chain 2: 1000 transitions using 10 leapfrog steps per transition would take 0.02 seconds.
## Chain 2: Adjust your expectations accordingly!
## Chain 2: 
## Chain 2: 
## Chain 2: Iteration:   1 / 1000 [  0%]  (Warmup)
## Chain 2: Iteration: 100 / 1000 [ 10%]  (Warmup)
## Chain 2: Iteration: 200 / 1000 [ 20%]  (Warmup)
## Chain 2: Iteration: 300 / 1000 [ 30%]  (Warmup)
## Chain 2: Iteration: 400 / 1000 [ 40%]  (Warmup)
## Chain 2: Iteration: 500 / 1000 [ 50%]  (Warmup)
## Chain 2: Iteration: 501 / 1000 [ 50%]  (Sampling)
## Chain 2: Iteration: 600 / 1000 [ 60%]  (Sampling)
## Chain 2: Iteration: 700 / 1000 [ 70%]  (Sampling)
## Chain 2: Iteration: 800 / 1000 [ 80%]  (Sampling)
## Chain 2: Iteration: 900 / 1000 [ 90%]  (Sampling)
## Chain 2: Iteration: 1000 / 1000 [100%]  (Sampling)
## Chain 2: 
## Chain 2:  Elapsed Time: 0.040616 seconds (Warm-up)
## Chain 2:                0.047865 seconds (Sampling)
## Chain 2:                0.088481 seconds (Total)
## Chain 2:
## Warning: There were 207 divergent transitions after warmup. Increasing adapt_delta above 0.95 may help. See
## http://mc-stan.org/misc/warnings.html#divergent-transitions-after-warmup
## Warning: Examine the pairs() plot to diagnose sampling problems
precis( m9.2 )
##             mean       sd        5.5%     94.5%    n_eff     Rhat
## alpha  -9.806072 134.6088 -171.409992  98.38377 167.3113 1.001382
## sigma 162.573888 514.4470    2.587262 827.92693  73.7844 1.020506
traceplot( m9.2 )

  • The problem is (nearly) flat priors, which encourage the sampler to explore the log-posterior out to values in the thousands…
  • This is one reason Maximum Likelihood Estimation can be problematic!
  • This is also a problem for Gibbs sampling.
  • Weakly informative priors to the rescue

A tame chain

set.seed(11)
m9.3 <- ulam(
    alist(
        y ~ dnorm( mu, sigma ),
        mu <- alpha,
        ## even include a "bad" starting point
        alpha ~ dnorm(1, 10) ,
        sigma ~ dexp( 1)
    ),
    data = list(y = y), chains = 2)
## Cautionary note:
## Variable y contains only integers but is not type 'integer'. If you intend it as an index variable, you should as.integer() it before passing to ulam.
## recompiling to avoid crashing R session
## 
## SAMPLING FOR MODEL 'db8b93ccfa83872ce482c35ebed2c618' NOW (CHAIN 1).
## Chain 1: 
## Chain 1: Gradient evaluation took 1.1e-05 seconds
## Chain 1: 1000 transitions using 10 leapfrog steps per transition would take 0.11 seconds.
## Chain 1: Adjust your expectations accordingly!
## Chain 1: 
## Chain 1: 
## Chain 1: Iteration:   1 / 1000 [  0%]  (Warmup)
## Chain 1: Iteration: 100 / 1000 [ 10%]  (Warmup)
## Chain 1: Iteration: 200 / 1000 [ 20%]  (Warmup)
## Chain 1: Iteration: 300 / 1000 [ 30%]  (Warmup)
## Chain 1: Iteration: 400 / 1000 [ 40%]  (Warmup)
## Chain 1: Iteration: 500 / 1000 [ 50%]  (Warmup)
## Chain 1: Iteration: 501 / 1000 [ 50%]  (Sampling)
## Chain 1: Iteration: 600 / 1000 [ 60%]  (Sampling)
## Chain 1: Iteration: 700 / 1000 [ 70%]  (Sampling)
## Chain 1: Iteration: 800 / 1000 [ 80%]  (Sampling)
## Chain 1: Iteration: 900 / 1000 [ 90%]  (Sampling)
## Chain 1: Iteration: 1000 / 1000 [100%]  (Sampling)
## Chain 1: 
## Chain 1:  Elapsed Time: 0.009138 seconds (Warm-up)
## Chain 1:                0.008345 seconds (Sampling)
## Chain 1:                0.017483 seconds (Total)
## Chain 1: 
## 
## SAMPLING FOR MODEL 'db8b93ccfa83872ce482c35ebed2c618' NOW (CHAIN 2).
## Chain 2: 
## Chain 2: Gradient evaluation took 3e-06 seconds
## Chain 2: 1000 transitions using 10 leapfrog steps per transition would take 0.03 seconds.
## Chain 2: Adjust your expectations accordingly!
## Chain 2: 
## Chain 2: 
## Chain 2: Iteration:   1 / 1000 [  0%]  (Warmup)
## Chain 2: Iteration: 100 / 1000 [ 10%]  (Warmup)
## Chain 2: Iteration: 200 / 1000 [ 20%]  (Warmup)
## Chain 2: Iteration: 300 / 1000 [ 30%]  (Warmup)
## Chain 2: Iteration: 400 / 1000 [ 40%]  (Warmup)
## Chain 2: Iteration: 500 / 1000 [ 50%]  (Warmup)
## Chain 2: Iteration: 501 / 1000 [ 50%]  (Sampling)
## Chain 2: Iteration: 600 / 1000 [ 60%]  (Sampling)
## Chain 2: Iteration: 700 / 1000 [ 70%]  (Sampling)
## Chain 2: Iteration: 800 / 1000 [ 80%]  (Sampling)
## Chain 2: Iteration: 900 / 1000 [ 90%]  (Sampling)
## Chain 2: Iteration: 1000 / 1000 [100%]  (Sampling)
## Chain 2: 
## Chain 2:  Elapsed Time: 0.009027 seconds (Warm-up)
## Chain 2:                0.007937 seconds (Sampling)
## Chain 2:                0.016964 seconds (Total)
## Chain 2:
precis( m9.3 )
##             mean       sd       5.5%    94.5%    n_eff     Rhat
## alpha 0.04516955 1.253826 -1.9903853 1.801000 213.1396 1.008117
## sigma 1.55272619 0.794507  0.7041972 3.097116 265.3611 1.013683
traceplot( m9.3 )

“Folk Theorem” of statistical computing

Folk theorem (Gelman): “When you have computational problems, often there’s a problem with your model.”

Online lectures

Closing (5 - 10 min)

In what situations do we need to use more advanced and efficient MC algorithms?